use anyhow::{Context, bail};
use cairo_lang_defs::db::DefsGroup;
use cairo_lang_defs::ids::{
    FreeFunctionId, LanguageElementId, LookupItemId, ModuleId, ModuleItemId,
    NamedLanguageElementId, SubmoduleId,
};
use cairo_lang_diagnostics::ToOption;
use cairo_lang_filesystem::ids::{CrateId, SmolStrId};
use cairo_lang_semantic::diagnostic::SemanticDiagnostics;
use cairo_lang_semantic::expr::inference::InferenceId;
use cairo_lang_semantic::expr::inference::canonic::ResultNoErrEx;
use cairo_lang_semantic::items::constant::ConstantSemantic;
use cairo_lang_semantic::items::functions::{
    ConcreteFunctionWithBodyId as SemanticConcreteFunctionWithBodyId, GenericFunctionId,
};
use cairo_lang_semantic::items::imp::ImplLongId;
use cairo_lang_semantic::items::impl_alias::ImplAliasSemantic;
use cairo_lang_semantic::items::module::ModuleSemantic;
use cairo_lang_semantic::items::us::SemanticUseEx;
use cairo_lang_semantic::resolve::{ResolvedGenericItem, Resolver};
use cairo_lang_semantic::substitution::SemanticRewriter;
use cairo_lang_sierra::ids::FunctionId;
use cairo_lang_sierra_generator::db::SierraGenGroup;
use cairo_lang_sierra_generator::replace_ids::SierraIdReplacer;
use cairo_lang_starknet_classes::keccak::starknet_keccak;
use cairo_lang_syntax::node::helpers::{GetIdentifier, PathSegmentEx, QueryAttrs};
use cairo_lang_syntax::node::{Terminal, TypedStablePtr, TypedSyntaxNode};
use cairo_lang_utils::extract_matches;
use cairo_lang_utils::ordered_hash_map::{
    OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
};
use itertools::chain;
use salsa::Database;
use serde::{Deserialize, Serialize};
use starknet_types_core::felt::Felt as Felt252;
use {cairo_lang_lowering as lowering, cairo_lang_semantic as semantic};

use crate::aliased::Aliased;
use crate::compile::{SemanticEntryPoints, extract_semantic_entrypoints};
use crate::plugin::aux_data::StarknetContractAuxData;
use crate::plugin::consts::{ABI_ATTR, ABI_ATTR_EMBED_V0_ARG};

#[cfg(test)]
#[path = "contract_test.rs"]
mod test;

/// A declared contract, identified by the submodule in which it is defined.
#[derive(Clone)]
pub struct ContractDeclaration<'db> {
    /// The submodule that defines the contract.
    pub submodule_id: SubmoduleId<'db>,
}

impl<'db> ContractDeclaration<'db> {
    /// Returns the defining submodule of this contract as a general [`ModuleId`].
    pub fn module_id(&self) -> ModuleId<'db> {
        let Self { submodule_id } = self;
        ModuleId::Submodule(*submodule_id)
    }
}

/// Returns the contract declaration of a given module if it is a contract module.
pub fn module_contract<'db>(
    db: &'db dyn Database,
    module_id: ModuleId<'db>,
) -> Option<ContractDeclaration<'db>> {
    let all_aux_data = module_id.module_data(db).ok()?.generated_file_aux_data(db);

    // When a module is generated by a plugin the same aux data appears in two
    // places:
    //   1. db.module_generated_file_aux_data(*original_module_id)?[k] (with k > 0).
    //   2. db.module_generated_file_aux_data(*generated_module_id)?[0].
    // We are interested in modules that the plugin acted on and not modules that were
    // created by the plugin, so we skip all_aux_data[0].
    // For example if we have
    // mod a {
    //    #[starknet::contract]
    //    mod b {
    //    }
    // }
    // Then we want lookup b inside a and not inside b.
    all_aux_data.values().skip(1).find_map(|aux_data| {
        let StarknetContractAuxData { contract_name } =
            aux_data.as_ref()?.as_any().downcast_ref()?;
        if let ModuleId::Submodule(submodule_id) = module_id {
            Some(ContractDeclaration { submodule_id })
        } else {
            unreachable!("Contract `{contract_name}` was not found.");
        }
    })
}

/// Collects the contract declarations of every module in the given crates.
///
/// Each module of each crate in `crate_ids` is checked with [`module_contract`]. Declarations are
/// returned in crate order, then in module order within each crate.
pub fn find_contracts<'db>(
    db: &'db dyn Database,
    crate_ids: &[CrateId<'db>],
) -> Vec<ContractDeclaration<'db>> {
    let mut found = Vec::new();
    for &crate_id in crate_ids {
        for &module_id in db.crate_modules(crate_id).iter() {
            if let Some(declaration) = module_contract(db, module_id) {
                found.push(declaration);
            }
        }
    }
    found
}

/// Returns the ABI functions of a given contract.
/// Assumes the given module is a contract module.
pub fn get_contract_abi_functions<'db>(
    db: &'db dyn Database,
    contract: &ContractDeclaration<'db>,
    module_name: &'db str,
) -> anyhow::Result<Vec<Aliased<semantic::ConcreteFunctionWithBodyId<'db>>>> {
    let module_name = SmolStrId::from(db, module_name);
    Ok(chain!(
        get_contract_internal_module_abi_functions(db, contract, module_name)?,
        get_impl_aliases_abi_functions(db, contract, module_name)?
    )
    .collect())
}

/// Returns the ABI wrapper functions of the internal module `module_name` of the contract.
///
/// The module is looked up inside the contract's generated module. Each aliased free function in
/// it is converted into a concrete function with body.
///
/// # Errors
/// Fails if the lookups fail, or if `from_no_generics_free` yields `None` for any wrapper (per the
/// error message, a wrapper with generics).
fn get_contract_internal_module_abi_functions<'db>(
    db: &'db dyn Database,
    contract: &ContractDeclaration<'db>,
    module_name: SmolStrId<'db>,
) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId<'db>>>> {
    let contract_module_id = get_generated_contract_module(db, contract)?;
    let internal_module_id = get_submodule_id(db, contract_module_id, module_name)?;
    let aliased_wrappers = get_module_aliased_functions(db, internal_module_id)?;
    let concrete_wrappers: Option<Vec<_>> = aliased_wrappers
        .into_iter()
        .map(|wrapper| {
            wrapper.try_map(|free_function_id| {
                semantic::ConcreteFunctionWithBodyId::from_no_generics_free(db, free_function_id)
            })
        })
        .collect();
    concrete_wrappers.with_context(|| "Generics are not allowed in wrapper functions")
}

/// Returns the free functions brought into `module_id` by `use` items, paired with their aliases.
///
/// Each alias is the identifier of the `use` leaf. Assumes the module is a generated module whose
/// `use` items point at ABI wrapper functions.
///
/// # Errors
/// Fails if the module's uses or a used item cannot be resolved, or if any `use` item does not
/// resolve to a free function.
fn get_module_aliased_functions<'db>(
    db: &'db dyn Database,
    module_id: ModuleId<'db>,
) -> anyhow::Result<Vec<Aliased<FreeFunctionId<'db>>>> {
    let uses = module_id
        .module_data(db)
        .map(|data| data.uses(db))
        .to_option()
        .with_context(|| "Failed to get external module uses.")?;

    let mut aliased_functions = Vec::new();
    for (use_id, leaf) in uses.iter() {
        let resolved = db
            .use_resolved_item(*use_id)
            .to_option()
            .with_context(|| "Failed to fetch used function.")?;
        let ResolvedGenericItem::GenericFunction(GenericFunctionId::Free(function_id)) = resolved
        else {
            bail!("Expected a free function.");
        };
        let alias = leaf.stable_ptr(db).identifier(db).to_string(db);
        aliased_functions.push(Aliased { value: function_id, alias });
    }
    Ok(aliased_functions)
}

/// Returns the abi functions of the impl aliases embedded in the given contract.
/// `module_prefix` is the prefix of the generated module name outside of the contract, the rest of
/// the name is defined by the name of the aliased impl.
///
/// Only impl aliases in the contract's generated module that carry the `ABI_ATTR` attribute with
/// the `ABI_ATTR_EMBED_V0_ARG` argument are considered. For each of them, the wrapper module is
/// looked up in the module that defines the aliased impl, under the name
/// `{module_prefix}_{impl_name}`. Its wrappers are specialized with the generic arguments of the
/// last segment of the alias's impl path.
///
/// # Errors
/// Fails if a module lookup fails, or if an alias's resolved impl cannot be obtained or is not a
/// concrete impl.
///
/// # Panics
/// Panics if inference cannot be finalized after specializing a wrapper, or, presumably, if any
/// semantic diagnostics were collected during specialization (see `expect_with_db` below).
fn get_impl_aliases_abi_functions<'db>(
    db: &'db dyn Database,
    contract: &ContractDeclaration<'db>,
    module_prefix: SmolStrId<'db>,
) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId<'db>>>> {
    let generated_module_id = get_generated_contract_module(db, contract)?;
    // Shared by all specializations below and checked once at the end.
    let mut diagnostics = SemanticDiagnostics::default();
    let mut all_abi_functions = vec![];
    for (impl_alias_id, impl_alias) in generated_module_id
        .module_data(db)
        .to_option()
        .with_context(|| "Failed to get external module impl aliases.")?
        .impl_aliases(db)
        .iter()
    {
        // Only embedded impl aliases contribute ABI functions.
        if !impl_alias.has_attr_with_arg(db, ABI_ATTR, ABI_ATTR_EMBED_V0_ARG) {
            continue;
        }
        let Ok(resolved_impl) = db.impl_alias_resolved_impl(*impl_alias_id) else {
            bail!("Internal error: Failed to get impl alias solution.");
        };
        let ImplLongId::Concrete(concrete) = resolved_impl.long(db) else {
            bail!("Internal error: Solved impl alias is not an impl.");
        };
        // The wrapper module lives next to the impl definition and is named after the impl.
        let impl_def_id = concrete.long(db).impl_def_id;
        let impl_module = impl_def_id.parent_module(db);
        let impl_name = impl_def_id.name_identifier(db).text(db).long(db);
        let module_id = get_submodule_id(
            db,
            impl_module,
            SmolStrId::from(db, format!("{}_{impl_name}", module_prefix.long(db))),
        )?;
        // One resolver per impl alias, in the context of the alias's declaration. It is shared by
        // all of this alias's wrappers.
        let mut resolver = Resolver::new(
            db,
            impl_alias_id.parent_module(db),
            InferenceId::LookupItemDeclaration(LookupItemId::ModuleItem(ModuleItemId::ImplAlias(
                *impl_alias_id,
            ))),
        );
        // The generic arguments written on the aliased impl path (e.g. `Impl<...>`) are reused to
        // specialize the wrappers.
        let Some(last_segment) = impl_alias.impl_path(db).segments(db).elements(db).last() else {
            unreachable!("impl_path should have at least one segment");
        };
        let generic_args = last_segment.generic_args(db).unwrap_or_default();
        for abi_function in get_module_aliased_functions(db, module_id)? {
            // `try_map` yields `None` when the closure does, so a wrapper whose specialization
            // fails or whose concrete function has no body is skipped here.
            all_abi_functions.extend(abi_function.try_map(|f| {
                let concrete_wrapper = resolver
                    .specialize_function(
                        &mut diagnostics,
                        impl_alias.stable_ptr(db).untyped(),
                        GenericFunctionId::Free(f),
                        &generic_args,
                    )
                    .to_option()?
                    .get_concrete(db)
                    .body(db)
                    .to_option()??;
                let inference = &mut resolver.inference();
                assert_eq!(
                    inference.finalize_without_reporting(),
                    Ok(()),
                    "All inferences should be solved at this point."
                );
                // Substitute the solved inference variables into the concrete wrapper.
                Some(inference.rewrite(concrete_wrapper).no_err())
            }));
        }
    }
    diagnostics
        .build()
        .expect_with_db(db, "Internal error: Inference for wrappers generics failed.");
    Ok(all_abi_functions)
}

/// Returns the generated contract module.
fn get_generated_contract_module<'db>(
    db: &'db dyn Database,
    contract: &ContractDeclaration<'db>,
) -> anyhow::Result<ModuleId<'db>> {
    let parent_module_id = contract.submodule_id.parent_module(db);
    let contract_name = contract.submodule_id.name(db);

    match db
        .module_item_by_name(parent_module_id, contract_name)
        .to_option()
        .with_context(|| "Failed to initiate a lookup in the root module.")?
    {
        Some(ModuleItemId::Submodule(generated_module_id)) => {
            Ok(ModuleId::Submodule(generated_module_id))
        }
        _ => anyhow::bail!(format!("Failed to get generated module {}.", contract_name.long(db))),
    }
}

/// Returns the module id of the submodule of a module.
fn get_submodule_id<'db>(
    db: &'db dyn Database,
    module_id: ModuleId<'db>,
    submodule_name: SmolStrId<'db>,
) -> anyhow::Result<ModuleId<'db>> {
    match db
        .module_item_by_name(module_id, submodule_name)
        .to_option()
        .with_context(|| "Failed to initiate a lookup in the {module_name} module.")?
    {
        Some(ModuleItemId::Submodule(submodule_id)) => Ok(ModuleId::Submodule(submodule_id)),
        _ => anyhow::bail!(
            "Failed to get the submodule `{}` of `{}`.",
            submodule_name.long(db),
            module_id.full_path(db)
        ),
    }
}

/// Sierra information of a contract: the Sierra function ids of its entry points.
///
/// Field declaration order is kept as-is because the serde derives (de)serialize fields in that
/// order.
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug, Eq)]
pub struct ContractInfo {
    /// Sierra function of the constructor, or `None` if the contract has no constructor.
    pub constructor: Option<FunctionId>,
    /// Sierra functions of the external functions, keyed by entry-point selector.
    /// (De)serialized through the `*_ordered_hashmap_vec` helpers (presumably as a list of
    /// key/value pairs).
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub externals: OrderedHashMap<Felt252, FunctionId>,
    /// Sierra functions of the l1 handler functions, keyed by entry-point selector.
    /// (De)serialized the same way as `externals`.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub l1_handlers: OrderedHashMap<Felt252, FunctionId>,
}

/// Analyzes each of the given contracts and returns their Sierra information keyed by class hash.
///
/// Sierra function ids are passed through `replacer` before being recorded.
///
/// # Errors
/// Returns the first error produced while analyzing a contract.
pub fn get_contracts_info<T: SierraIdReplacer>(
    db: &dyn Database,
    contracts: Vec<ContractDeclaration<'_>>,
    replacer: &T,
) -> Result<OrderedHashMap<Felt252, ContractInfo>, anyhow::Error> {
    let mut contracts_info = OrderedHashMap::default();
    for contract in &contracts {
        let (class_hash, contract_info) = analyze_contract(db, contract, replacer)?;
        contracts_info.insert(class_hash, contract_info);
    }
    Ok(contracts_info)
}

/// Analyzes a contract and returns its class hash together with its Sierra entry points.
///
/// The class hash is read from the `TEST_CLASS_HASH` constant defined in the contract module.
/// External and L1-handler functions are keyed by their selectors. Only the first constructor, if
/// any, is recorded.
///
/// # Errors
/// Fails if `TEST_CLASS_HASH` cannot be found or does not evaluate to an integer, or if the
/// semantic entry points cannot be extracted.
///
/// # Panics
/// Panics if the item named `TEST_CLASS_HASH` is not a constant, or if an entry point's lowered
/// function id cannot be obtained (see [`get_selector_and_sierra_function`]).
fn analyze_contract<'db, T: SierraIdReplacer>(
    db: &dyn Database,
    contract: &ContractDeclaration<'db>,
    replacer: &T,
) -> anyhow::Result<(Felt252, ContractInfo)> {
    // Extract class hash.
    let item = db
        .module_item_by_name(contract.module_id(), SmolStrId::from(db, "TEST_CLASS_HASH"))
        .to_option()
        .context("Failed to look up `TEST_CLASS_HASH` in the contract module.")?
        .context("`TEST_CLASS_HASH` is not defined in the contract module.")?;
    // A `TEST_CLASS_HASH` item that is not a constant is treated as an invariant violation.
    let constant_id = extract_matches!(item, ModuleItemId::Constant);
    let class_hash = Felt252::from(
        db.constant_const_value(constant_id)
            .to_option()
            .context("Failed to evaluate `TEST_CLASS_HASH`.")?
            .long(db)
            .to_int()
            .context("`TEST_CLASS_HASH` is not an integer constant.")?,
    );

    // Extract functions.
    let SemanticEntryPoints { external, l1_handler, constructor } =
        extract_semantic_entrypoints(db, contract)?;
    let externals =
        external.into_iter().map(|f| get_selector_and_sierra_function(db, &f, replacer)).collect();
    let l1_handlers = l1_handler
        .into_iter()
        .map(|f| get_selector_and_sierra_function(db, &f, replacer))
        .collect();
    // Only the first constructor is recorded, and its selector is not needed.
    let constructor = constructor
        .into_iter()
        .next()
        .map(|f| get_selector_and_sierra_function(db, &f, replacer).1);

    let contract_info = ContractInfo { externals, l1_handlers, constructor };
    Ok((class_hash, contract_info))
}

/// Converts a function to a Sierra function.
/// Returns the selector and the Sierra function ID.
pub fn get_selector_and_sierra_function<'db, T: SierraIdReplacer>(
    db: &dyn Database,
    function_with_body: &Aliased<lowering::ids::ConcreteFunctionWithBodyId<'db>>,
    replacer: &T,
) -> (Felt252, FunctionId) {
    let function_id = function_with_body.value.function_id(db).expect("Function error.");
    let sierra_id = replacer.replace_function_id(&db.intern_sierra_function(function_id));
    let selector: Felt252 = starknet_keccak(function_with_body.alias.as_bytes()).into();
    (selector, sierra_id)
}
